#!/usr/bin/env python3
# G20_v2 — Boundary Masks (+2/+3)
# CONTROL: present-act, boolean/ordinal. 1/r DDA per shell — no curves, no weights, no RNG in control.
# PANELS on the same grid:
#   BASE: no mask
#   MASK: isotropic boundary mask (checker or frame) applied to cell set (diagnostics-only geometry)
# READOUTS (diagnostics-only):
#   • Log-slope (mid-60%) and R^2
#   • Plateau CV on equal-Δr bins (outer fraction)
#   • Amplitude ratio MASK/BASE vs predicted fraction from mask geometry
#   • Azimuth flatness (relative RMSE across S sectors) over a common interior window

import argparse, csv, hashlib, json, math, os, sys
from datetime import datetime, timezone
from typing import Dict, List, Tuple

def utc_timestamp() -> str:
    """Return the current UTC time as a stamp such as ``2024-01-31T12-00-00Z``.

    Hyphens replace the usual colons in the time part.
    """
    stamp_format = "%Y-%m-%dT%H-%M-%SZ"
    now_utc = datetime.now(tz=timezone.utc)
    return now_utc.strftime(stamp_format)

def ensure_dirs(root, subs):
    """Create every sub-directory named in *subs* beneath *root*.

    Intermediate directories are created as needed; directories that already
    exist are left untouched.
    """
    for relative in subs:
        target = os.path.join(root, relative)
        os.makedirs(target, exist_ok=True)

def write_text(path, txt):
    """Write the string *txt* to *path* as UTF-8, replacing any existing content."""
    with open(path, mode="w", encoding="utf-8") as handle:
        handle.write(txt)

def json_dump(path, obj):
    """Serialize *obj* as JSON into *path* (UTF-8, 2-space indent, keys sorted)."""
    dump_options = {"indent": 2, "sort_keys": True}
    with open(path, mode="w", encoding="utf-8") as handle:
        json.dump(obj, handle, **dump_options)

def sha256_file(path):
    """Return the hex SHA-256 digest of the file at *path*.

    The file is read in binary mode in 1 MiB pieces so it never has to be
    held in memory all at once.
    """
    piece_size = 1 << 20
    digest = hashlib.sha256()
    with open(path, "rb") as stream:
        while True:
            piece = stream.read(piece_size)
            if not piece:
                break
            digest.update(piece)
    return digest.hexdigest()

def isqrt(n: int) -> int:
    """Return the integer (floor) square root of *n* via ``math.isqrt``."""
    root = math.isqrt(n)
    return int(root)

def sector_index(x: int, y: int, cx: int, cy: int, S: int) -> int:
    """Return the angular sector (0 .. S-1) of cell (x, y) as seen from (cx, cy).

    The angle is ``atan2(y - cy, x - cx)`` shifted into [0, 2*pi); the circle is
    split into S equal sectors starting at angle 0.
    """
    full_turn = 2.0 * math.pi
    theta = math.atan2(y - cy, x - cx)
    if theta < 0:
        theta += full_turn
    index = int((theta / full_turn) * S)
    # Clamp in case rounding pushes the scaled angle up to S.
    return min(index, S - 1)

def keep_cell_mask(mask: dict, x: int, y: int, N: int) -> bool:
    """Tell whether cell (x, y) of an N x N grid survives the given mask.

    ``mask["type"]`` selects the rule (default ``"none"``):

    * ``"checker"``: keep the cell when ``(x + y) % period < keep``
      (defaults: period 6, keep 3).
    * ``"frame"``: keep the cell when both coordinates lie in
      ``[frame_width, N - frame_width)`` (default frame_width 8).
    * ``"none"`` or any other value: keep every cell.
    """
    kind = mask.get("type", "none")

    if kind == "checker":
        period = int(mask.get("period", 6))
        kept_phases = int(mask.get("keep", 3))       # predicted frac = keep/period
        return (x + y) % period < kept_phases

    if kind == "frame":
        border = int(mask.get("frame_width", 8))
        upper = N - border
        return border <= x < upper and border <= y < upper

    # "none" and unrecognised types keep everything.
    return True

def build_counts(N:int, cx:int, cy:int, S:int, mask:dict) -> Tuple[Dict[int,int], Dict[int,List[int]], int]:
    """Count the kept cells of an N x N grid per integer shell and per sector.

    A cell's shell is the integer square root of its squared distance from
    (cx, cy); its sector comes from ``sector_index``. Cells rejected by
    ``keep_cell_mask`` are ignored.

    Returns a tuple ``(shell_counts, shell_sector_counts, R_edge)``:
    ``shell_counts[r]`` is the number of kept cells in shell r,
    ``shell_sector_counts[r]`` is a length-S list of per-sector counts, and
    ``R_edge`` is the distance from the centre to the nearest grid border.
    """
    per_shell: Dict[int, int] = {}
    per_shell_sector: Dict[int, List[int]] = {}

    for y in range(N):
        dy = y - cy
        for x in range(N):
            if keep_cell_mask(mask, x, y, N):
                dx = x - cx
                shell = isqrt(dx * dx + dy * dy)
                sector = sector_index(x, y, cx, cy, S)

                per_shell[shell] = per_shell.get(shell, 0) + 1

                row = per_shell_sector.get(shell)
                if row is None:
                    row = [0] * S
                    per_shell_sector[shell] = row
                row[sector] += 1

    edge_radius = min(cx, cy, (N - 1) - cx, (N - 1) - cy)
    return per_shell, per_shell_sector, edge_radius

def simulate_dda(shell_counts: Dict[int,int], H:int, rate_num:int) -> Dict[int,int]:
    """Run an integer accumulator for each shell over H ticks and count its fires.

    On every tick a shell's accumulator grows by *rate_num*; when it reaches
    the shell radius r the shell fires once and r is subtracted. Shell 0 is
    skipped and always reports 0 fires. Only the keys of *shell_counts* are
    used.

    Returns a dict mapping each shell radius to its fire count.
    """
    fired: Dict[int, int] = {}
    # Shells never interact, so each one is stepped through all H ticks in turn.
    for radius in shell_counts:
        hits = 0
        if radius != 0:  # skip center in control/readouts
            accumulator = 0
            for _tick in range(H):
                accumulator += rate_num
                if accumulator >= radius:
                    hits += 1
                    accumulator -= radius
        fired[radius] = hits
    return fired

def linreg_y_on_x(xs, ys):
    """Ordinary least-squares fit of ys against xs.

    Returns ``(slope, r2, n)``. With fewer than two points the result is
    ``(nan, nan, 0)``; when all xs are equal it is ``(nan, nan, n)``. If the
    ys have zero total variance, r2 is reported as 1.0.
    """
    nan = float("nan")
    count = len(xs)
    if count < 2:
        return nan, nan, 0

    mean_x = sum(xs) / count
    mean_y = sum(ys) / count

    sxx = sum((x - mean_x) * (x - mean_x) for x in xs)
    if sxx == 0:
        return nan, nan, count
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))

    slope = sxy / sxx
    intercept = mean_y - slope * mean_x

    total_ss = sum((y - mean_y) * (y - mean_y) for y in ys)
    resid_ss = sum(
        (y - (intercept + slope * x)) * (y - (intercept + slope * x))
        for x, y in zip(xs, ys)
    )
    if total_ss > 0:
        unexplained = resid_ss / total_ss
    else:
        unexplained = 0.0
    return slope, 1.0 - unexplained, count

def build_log_edges(r_lo:int, r_hi:int, n_bins:int) -> List[float]:
    """Return ``n_bins + 1`` bin edges equally spaced in log(r) from r_lo to r_hi.

    Both limits are raised to at least 1 before taking the logarithm.
    """
    log_lo = math.log(max(1, r_lo))
    log_hi = math.log(max(1, r_hi))
    edges: List[float] = []
    for i in range(n_bins + 1):
        edges.append(math.exp(log_lo + (log_hi - log_lo) * i / n_bins))
    return edges

def shells_in_range_int(rs: List[int], lo: float, hi: float) -> List[int]:
    """Return the entries of *rs* lying in ``[ceil(lo), floor(hi)]``, in input order."""
    first = math.ceil(lo)
    last = math.floor(hi)
    selected: List[int] = []
    for r in rs:
        if first <= r <= last:
            selected.append(r)
    return selected

def slope_panel(shell_counts, fires, H, r_min, r_max, log_edges, fit_mid_frac):
    """Fit the log-log slope of mean fire rate against ``r_outer / r`` over log bins.

    Shells with ``r_min <= r <= r_max`` and ``r > 0`` are grouped by consecutive
    pairs of *log_edges*. For each non-empty bin the x value is
    ``log(r_outer / r_rep)``, where ``r_outer`` is the largest selected shell
    and ``r_rep`` the geometric mean of the bin's first and last shell. The y
    value is the log of the bin's mean rate ``fires / H``. Only the central
    *fit_mid_frac* share of the bins is fitted, when that leaves at least two
    points.

    Returns ``{"slope", "r2", "bins"}``; with fewer than four non-empty bins
    slope and r2 are NaN and bins is 0.
    """
    tiny = 1e-12  # keeps log() defined when a ratio or rate is zero
    radii = [r for r in sorted(shell_counts) if r > 0 and r_min <= r <= r_max]

    log_x, log_y = [], []
    for edge_index in range(len(log_edges) - 1):
        members = shells_in_range_int(radii, log_edges[edge_index], log_edges[edge_index + 1])
        if members:
            bin_rates = [fires.get(r, 0) / H for r in members]
            r_rep = math.exp((math.log(members[0]) + math.log(members[-1])) / 2.0)
            log_x.append(math.log(radii[-1] / r_rep + tiny))
            log_y.append(math.log(sum(bin_rates) / len(bin_rates) + tiny))

    n_points = len(log_x)
    if n_points < 4:
        return {"slope": float("nan"), "r2": float("nan"), "bins": 0}

    # Drop the same number of bins from each end to keep the middle fraction.
    trim = int(round(n_points * (1.0 - fit_mid_frac) / 2.0))
    if n_points - 2 * trim >= 2:
        fit_x = log_x[trim:n_points - trim]
        fit_y = log_y[trim:n_points - trim]
    else:
        fit_x, fit_y = log_x, log_y

    slope, r2, _used = linreg_y_on_x(fit_x, fit_y)
    return {"slope": slope, "r2": r2, "bins": n_points}

def plateau_panel(shell_counts, fires, H, r_min, r_max, shells_per_bin, outer_frac):
    """Measure how flat the per-bin amplitude is over the outer equal-width bins.

    Shells r_min..r_max are grouped into consecutive bins of *shells_per_bin*
    shells; a trailing partial bin is dropped. Each bin's amplitude is the sum
    of ``shell_counts[r] * fires[r] / H`` over its shells (missing shells count
    as 0). The statistics use the last ``max(1, round(nbins * outer_frac))``
    bins.

    Returns a dict with:
      cv       -- population standard deviation / mean of the outer bins
                  (inf when the mean is exactly 0, NaN when no bin fits)
      amp_mean -- mean amplitude of the outer bins (NaN when no bin fits)
      nbins    -- total number of complete bins (0 when no bin fits)

    Raises:
      ValueError: if *shells_per_bin* is smaller than 1. Without this check a
        zero width fails inside the modulo below and a negative width never
        lets the binning loop terminate.
    """
    width = shells_per_bin
    if width < 1:
        raise ValueError(f"shells_per_bin must be >= 1, got {shells_per_bin!r}")

    def _no_bins():
        return {"cv": float("nan"), "amp_mean": float("nan"), "nbins": 0}

    if r_max <= r_min + width:
        return _no_bins()

    # Last shell that still belongs to a complete bin.
    stop = r_max - ((r_max - r_min + 1) % width)
    if stop < r_min + width - 1:
        return _no_bins()

    bins = []
    start = r_min
    while start + width - 1 <= stop:
        amplitude = 0.0
        for rr in range(start, start + width):
            amplitude += shell_counts.get(rr, 0) * (fires.get(rr, 0) / H)
        bins.append(amplitude)
        start += width
    if not bins:
        return _no_bins()

    k = len(bins)
    take = max(1, int(round(k * outer_frac)))
    outer = bins[-take:]
    mu = sum(outer) / len(outer)
    if mu == 0.0:
        return {"cv": float("inf"), "amp_mean": 0.0, "nbins": k}
    variance = sum((v - mu) * (v - mu) for v in outer) / len(outer)
    return {"cv": math.sqrt(variance) / mu, "amp_mean": mu, "nbins": k}

def azimuth_flatness(shell_sector_counts, fires, H, r_min, r_max, S):
    """Return the relative spread of rate-weighted cell counts across the S sectors.

    For every shell r in r_min..r_max present in *shell_sector_counts*, each
    sector's count is weighted by the shell rate ``fires[r] / H`` and added to
    that sector's total. The result is the population standard deviation of
    the S totals divided by their mean, or ``inf`` when the mean is exactly 0.
    """
    totals = [0.0] * S
    for radius in range(r_min, r_max + 1):
        row = shell_sector_counts.get(radius)
        if row is None:
            continue
        rate = fires.get(radius, 0) / H
        totals = [running + row[s] * rate for s, running in enumerate(totals)]

    mean_total = sum(totals) / S
    if mean_total == 0.0:
        return float("inf")
    variance = sum((t - mean_total) * (t - mean_total) for t in totals) / S
    return math.sqrt(variance) / mean_total

def run_panel(N, cx, cy, S, mask, H, rate, r_min_slope, r_min_plat, r_max_glob, slope_cfg, plat_cfg):
    """Build one panel (grid + mask), run the DDA, and collect its readouts.

    Reads ``n_log_bins`` and ``fit_mid_frac`` from *slope_cfg* and
    ``shells_per_bin`` and ``outer_frac`` from *plat_cfg*. The slope fit starts
    at *r_min_slope*; the plateau and azimuth readouts start at *r_min_plat*;
    all three end at *r_max_glob*.

    Returns a dict with keys ``slope``, ``r2``, ``cv``, ``amp_mean`` and
    ``az_flat``.
    """
    counts, sector_counts, _edge = build_counts(N, cx, cy, S, mask)
    fires = simulate_dda(counts, H, rate)

    # Log-slope readout on shared log-edges.
    n_log_bins = int(slope_cfg["n_log_bins"])
    edges = build_log_edges(r_min_slope, r_max_glob, n_log_bins)
    mid_frac = float(slope_cfg["fit_mid_frac"])
    slope_res = slope_panel(counts, fires, H, r_min_slope, r_max_glob, edges, mid_frac)

    # Plateau CV readout on equal-width shell bins.
    shells_per_bin = int(plat_cfg["shells_per_bin"])
    outer_frac = float(plat_cfg["outer_frac"])
    plateau_res = plateau_panel(counts, fires, H, r_min_plat, r_max_glob, shells_per_bin, outer_frac)

    # Azimuth flatness over the same window as the plateau.
    flatness = azimuth_flatness(sector_counts, fires, H, r_min_plat, r_max_glob, S)

    return {
        "slope": slope_res["slope"],
        "r2": slope_res["r2"],
        "cv": plateau_res["cv"],
        "amp_mean": plateau_res["amp_mean"],
        "az_flat": flatness,
    }

def main():
    """Run the BASE and MASK panels from a JSON manifest and write all artifacts.

    Command line: ``--manifest <file>`` and ``--outdir <dir>`` (both required).

    Manifest keys read here: ``grid`` (N, cx, cy), ``sectors`` (S), ``H``,
    ``rate_num``, ``outer_margin``, ``slope``, ``plateau``, ``mask`` and
    ``acceptance``.

    Files written under the output directory:
      config/manifest_g20_v2.json        -- copy of the manifest (hashed in the audit)
      logs/env.txt                       -- UTC stamp, OS name, cwd, Python version
      outputs/metrics/g20_panels.csv     -- one row per panel
      outputs/audits/g20_audit.json      -- readouts, cross-checks and pass flag
      outputs/run_info/result_line.txt   -- one-line summary (also printed)

    Raises:
      RuntimeError: if H or S is not positive, or if the common interior
        radius ``R_edge - outer_margin`` is not positive.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--manifest", required=True)
    ap.add_argument("--outdir", required=True)
    args = ap.parse_args()

    root = os.path.abspath(args.outdir)
    ensure_dirs(root, ["config","outputs/metrics","outputs/audits","outputs/run_info","logs"])

    with open(args.manifest, "r", encoding="utf-8") as f:
        M = json.load(f)
    mpath = os.path.join(root, "config", "manifest_g20_v2.json")
    json_dump(mpath, M)

    # One key=value entry per line (a real newline, not the two characters "\" "n").
    write_text(os.path.join(root, "logs", "env.txt"),
               "\n".join([f"utc={utc_timestamp()}",
                          f"os={os.name}", f"cwd={os.getcwd()}",
                          f"python={sys.version.split()[0]}"]))

    N   = int(M["grid"]["N"])
    cx  = int(M["grid"]["cx"])
    cy  = int(M["grid"]["cy"])
    S   = int(M["sectors"]["S"])
    H   = int(M["H"])
    rn  = int(M["rate_num"])
    om  = int(M["outer_margin"])

    # Every rate is fires / H and every sector list has length S, so reject
    # non-positive values here rather than failing deep inside a readout.
    if H <= 0:
        raise RuntimeError("Invalid H; must be a positive integer.")
    if S <= 0:
        raise RuntimeError("Invalid sectors.S; must be a positive integer.")

    slope_cfg = M["slope"]; plat_cfg = M["plateau"]

    # common interior r_max
    R_edge = min(cx, cy, (N-1)-cx, (N-1)-cy)
    r_max_glob = R_edge - om
    if r_max_glob <= 0:
        raise RuntimeError("Invalid r_max_glob; decrease outer_margin or increase grid size.")

    base_mask = {"type": "none"}
    mask_cfg  = dict(M["mask"]); mask_cfg["N"] = N

    # run panels (identical windows and configs; only the mask differs)
    r_min_slope = int(slope_cfg.get("r_min", 8))
    r_min_plat  = int(plat_cfg["r_min"])
    base = run_panel(N, cx, cy, S, base_mask, H, rn,
                     r_min_slope, r_min_plat,
                     r_max_glob, slope_cfg, plat_cfg)
    masked = run_panel(N, cx, cy, S, mask_cfg, H, rn,
                       r_min_slope, r_min_plat,
                       r_max_glob, slope_cfg, plat_cfg)

    # predicted amplitude fraction
    pred_frac = 1.0
    if mask_cfg.get("type") == "checker":
        pred_frac = float(mask_cfg.get("keep", 3)) / float(mask_cfg.get("period", 6))
    elif mask_cfg.get("type") == "frame":
        frame_w = float(mask_cfg.get("frame_width", 8)); NN = float(mask_cfg.get("N", N))
        # Clamp the kept side length before squaring: a frame wider than half
        # the grid keeps no cells, so the predicted fraction must be 0.
        inner_side = max(0.0, NN - 2*frame_w)
        pred_frac = (inner_side*inner_side) / (NN*NN)

    # CSV
    mcsv = os.path.join(root, "outputs/metrics", "g20_panels.csv")
    with open(mcsv, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["panel","slope","r2","cv","amp_mean","az_flat","r_max_glob","pred_frac"])
        writer.writerow(["BASE", f"{base['slope']:.6f}", f"{base['r2']:.6f}", f"{base['cv']:.6f}",
                         f"{base['amp_mean']:.6f}", f"{base['az_flat']:.6f}", r_max_glob, "1.000000"])
        writer.writerow(["MASK", f"{masked['slope']:.6f}", f"{masked['r2']:.6f}", f"{masked['cv']:.6f}",
                         f"{masked['amp_mean']:.6f}", f"{masked['az_flat']:.6f}", r_max_glob, f"{pred_frac:.6f}"])

    # acceptance
    acc = M["acceptance"]
    slope_target   = float(acc["slope_target"])
    slope_tol_abs  = float(acc["slope_tol_abs"])
    r2_min         = float(acc["r2_min"])
    cv_max         = float(acc["cv_max"])
    delta_slope_max= float(acc["delta_slope_max"])
    delta_cv_max   = float(acc["delta_cv_max"])
    amp_ratio_tol  = float(acc["amp_ratio_tol_abs"])
    az_flat_max    = float(acc["az_flat_max"])

    dslope = abs(base["slope"] - masked["slope"])
    dcv    = abs(base["cv"]    - masked["cv"])
    amp_ratio = (masked["amp_mean"]/base["amp_mean"]) if (base["amp_mean"]>0) else float("inf")

    # NaN readouts make every comparison below False, so they fail acceptance.
    per_base_ok = (abs(base["slope"] - slope_target) <= slope_tol_abs and base["r2"] >= r2_min and base["cv"] <= cv_max and base["az_flat"] <= az_flat_max)
    per_mask_ok = (abs(masked["slope"] - slope_target) <= slope_tol_abs and masked["r2"] >= r2_min and masked["cv"] <= cv_max and masked["az_flat"] <= az_flat_max)
    inv_ok  = (dslope <= delta_slope_max) and (dcv <= delta_cv_max)
    amp_ok  = (abs(amp_ratio - pred_frac) <= amp_ratio_tol)

    passed = bool(per_base_ok and per_mask_ok and inv_ok and amp_ok)

    audit = {
        "sim": "G20_boundary_masks_v2",
        "gridN": N, "H": H, "rate_num": rn, "S": S, "outer_margin": om,
        "r_max_glob": r_max_glob,
        "base": base, "masked": masked,
        "cross": {"delta_slope": dslope, "delta_cv": dcv, "amp_ratio": amp_ratio, "pred_frac": pred_frac, "inv_ok": inv_ok, "amp_ok": amp_ok},
        "accept": acc,
        "pass": passed,
        "manifest_hash": sha256_file(mpath)
    }
    # The audit and run_info directories were already created above.
    json_dump(os.path.join(root, "outputs", "audits", "g20_audit.json"), audit)

    result_line = ("G20_v2 PASS={p} slope_B={sb:.4f} slope_M={sm:.4f} Δslope={ds:.4f} "
                   "cv_B={cb:.4f} cv_M={cm:.4f} Δcv={dc:.4f} amp_ratio={ar:.3f} pred={pf:.3f} "
                   "az_B={ab:.4f} az_M={am:.4f}"
                   .format(p=passed, sb=base["slope"], sm=masked["slope"], ds=dslope,
                           cb=base["cv"], cm=masked["cv"], dc=dcv,
                           ar=amp_ratio, pf=pred_frac,
                           ab=base["az_flat"], am=masked["az_flat"]))
    write_text(os.path.join(root, "outputs", "run_info", "result_line.txt"), result_line)
    print(result_line)

# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
